//
// Created by chonghao on 3/10/21.
//

#include "nanovoid_app.h"
#include <cstring>

// using namespace std;

// Build a solver for a _size x _size grid with three fields per cell.
//
// _size : number of cells along each axis.
// _p    : model parameters (copied into member p).
//
// Both value tables hold _size * _size * 3 entries. They are value-initialised
// (zeroed) so that calling decode_to_img() or next() before encode_from_img()
// reads well-defined data instead of indeterminate memory.
NanoVoidNormal::NanoVoidNormal(uint _size, ParameterSet _p) : p(_p) {
    this->value_table_size = _size * _size * 3;
    this->num_items = _size * _size;
    this->size = _size;

    // Trailing () value-initialises every element to zero.
    this->old_v = new valueType[this->value_table_size]();
    this->new_v = new valueType[this->value_table_size]();
}

// Release the two value tables.
//
// They are allocated with new[] in the constructor, so they must be released
// with delete[]; scalar delete on a new[] allocation is undefined behaviour.
// NOTE(review): the class owns raw buffers, so an implicit copy would
// double-free them; consider deleting the copy constructor/assignment in the
// header (Rule of Five).
NanoVoidNormal::~NanoVoidNormal() {
    delete[] this->old_v;
    delete[] this->new_v;
}

// Gather the stencil around cell `item` for all three channels.
//
// Stencil point k sits at offset (dx[k], dy[k]) from the cell. Offsets that
// would leave the grid are clamped to the nearest edge cell.
//
// Output layout: vals[k + ch * lap_len_2nd] receives channel ch (0, 1, 2) at
// stencil point k, read from value_table at the point's first-channel index
// plus ch * num_items.
void NanoVoidNormal::grab_vals(uint item, valueType *value_table, valueType *vals) {
    Coordinate2d3c centre(0, 0);
    centre.from_item(item, size, size * size);

    for (uint k = 0; k < lap_len_2nd; k++) {
        Coordinate2d3c probe(centre);

        int px = probe.x + dx[k];
        int py = probe.y + dy[k];

        // Clamp into [0, size - 1] on each axis.
        px = max(px, 0);
        px = min(px, size - 1);
        py = max(py, 0);
        py = min(py, size - 1);

        probe.x = px;
        probe.y = py;

        const uint base = probe.to_item_c1(size, size * size);
        for (uint ch = 0; ch < 3; ch++) {
            vals[k + ch * lap_len_2nd] = value_table[base + ch * num_items];
        }
    }
}

// Advance a single grid cell by one explicit time step.
//
// vals  : stencil gathered by grab_vals. vals[k] is cv, vals[lap_len_2nd + k]
//         is ci and vals[2 * lap_len_2nd + k] is eta at stencil point k
//         (k == 0 is the cell itself).
// c     : first-channel index of the cell. The updated ci and eta are written
//         at c + num_items and c + 2 * num_items.
// new_v : output table.
//
// After the update each of the three written values is projected: a value
// with its sign bit set becomes 0, a value >= 1 becomes 1. A NaN with a clear
// sign bit is left as is.
//
// verbose == 1 prints dF/deta and the unclamped new eta. verbose == 2 prints
// the whole stencil and the clamped results.
void NanoVoidNormal::forward_one_step(valueType *vals, uint c, valueType *new_v) {
    // Each parameter is used as |x| + 0.001, so every coefficient is strictly
    // positive (kBT in particular is a divisor below).
    // r_bulk0, r_surf0, p_casc0, bias0 and vg0 do not enter this update.
    const valueType energy_v = std::abs(p.energy_v0) + 0.001;
    const valueType energy_i = std::abs(p.energy_i0) + 0.001;
    const valueType kBT = std::abs(p.kBT0) + 0.001;
    const valueType kappa_v = std::abs(p.kappa_v0) + 0.001;
    const valueType kappa_i = std::abs(p.kappa_i0) + 0.001;
    const valueType kappa_eta = std::abs(p.kappa_eta0) + 0.001;
    const valueType diff_v = std::abs(p.diff_v0) + 0.001;
    const valueType diff_i = std::abs(p.diff_i0) + 0.001;
    const valueType L = std::abs(p.L0) + 0.001;

    // Fixed explicit time step.
    const valueType dt = 2e-2;

    // Views of the three channel stencils inside vals.
    valueType *const cv = vals;
    valueType *const ci = vals + lap_len_2nd;
    valueType *const eta = vals + lap_len_2nd * 2;

    // ---- cv / ci ---------------------------------------------------------
    // dfs/dcv and dfs/dci weighted by (eta - 1)^2, at the first lap_len_1st
    // stencil points (the support of lapw).
    valueType h_dfs_dcv[lap_len_1st];
    valueType h_dfs_dci[lap_len_1st];
    for (uint i = 0; i < lap_len_1st; i++) {
        const valueType log_cv = log_with_mask_single(cv[i], EPS);
        const valueType log_ci = log_with_mask_single(ci[i], EPS);
        const valueType log_1_cv_ci = log_with_mask_single(1 - cv[i] - ci[i], EPS);

        valueType dfs_dcv = energy_v + kBT * (log_cv - log_1_cv_ci);
        valueType dfs_dci = energy_i + kBT * (log_ci - log_1_cv_ci);
        // Contribution is dropped once 1 - cv - ci falls below EPS.
        if ((1 - cv[i] - ci[i]) < EPS) {
            dfs_dcv = 0.0;
            dfs_dci = 0.0;
        }

        // Left-to-right multiplication keeps the original rounding order.
        h_dfs_dcv[i] = dfs_dcv * (eta[i] - 1) * (eta[i] - 1);
        h_dfs_dci[i] = dfs_dci * (eta[i] - 1) * (eta[i] - 1);
    }

    // Derivatives of fv = (cv - 1)^2 + ci^2, weighted by eta^2.
    valueType j_dfv_dcv[lap_len_1st];
    valueType j_dfv_dci[lap_len_1st];
    for (uint i = 0; i < lap_len_1st; i++) {
        const valueType eta_sq = eta[i] * eta[i];
        j_dfv_dcv[i] = eta_sq * 2 * (cv[i] - 1);
        j_dfv_dci[i] = eta_sq * 2 * ci[i];
    }

    // Mobilities are evaluated at the centre cell only.
    const valueType mv = diff_v * cv[0] / kBT;
    const valueType mi = diff_i * ci[0] / kBT;

    // cv <- cv + dt*mv*(lap(h*dfs/dcv) + lap(j*dfv/dcv)) - kappa_v*dt*mv*laplap(cv)
    const valueType dt_mv_lap_h_dfs_dcv = dt * mv * inner_product(h_dfs_dcv, lapw, lap_len_1st);
    const valueType dt_mv_lap_j_dfv_dcv = dt * mv * inner_product(j_dfv_dcv, lapw, lap_len_1st);
    const valueType dt_mv_lap_lap_cv = - dt * mv * inner_product(cv, laplapw, lap_len_2nd);
    new_v[c] = cv[0] + dt_mv_lap_h_dfs_dcv + dt_mv_lap_j_dfv_dcv + kappa_v * dt_mv_lap_lap_cv;

    // Same form for ci with mi / kappa_i.
    const valueType dt_mi_lap_h_dfs_dci = dt * mi * inner_product(h_dfs_dci, lapw, lap_len_1st);
    const valueType dt_mi_lap_j_dfv_dci = dt * mi * inner_product(j_dfv_dci, lapw, lap_len_1st);
    const valueType dt_mi_lap_lap_ci = - dt * mi * inner_product(ci, laplapw, lap_len_2nd);
    new_v[c + num_items] = ci[0] + dt_mi_lap_h_dfs_dci + dt_mi_lap_j_dfv_dci + kappa_i * dt_mi_lap_lap_ci;

    // ---- eta -------------------------------------------------------------
    // fs at the centre cell (forced to 0 once 1 - cv - ci < EPS).
    valueType fs = energy_v * cv[0] + energy_i * ci[0];
    fs = fs + kBT * (cv[0] * log_with_mask_single(cv[0], EPS));
    fs = fs + kBT * (ci[0] * log_with_mask_single(ci[0], EPS));
    fs = fs + kBT * ((1 - cv[0] - ci[0]) * log_with_mask_single(1 - cv[0] - ci[0], EPS));
    if ((1 - cv[0] - ci[0]) < EPS) {
        fs = 0.0;
    }
    const valueType fv = (cv[0] - 1) * (cv[0] - 1) + ci[0] * ci[0];

    // dF/deta = N * (fs * d/deta (eta-1)^2 + fv * d/deta eta^2 - kappa_eta * lap(eta))
    const valueType dF_deta = N * (fs * 2 * (eta[0] - 1) + fv * 2 * eta[0] -
            kappa_eta * inner_product(eta, lapw, lap_len_1st));
    const valueType new_eta = eta[0] + dt * (-L) * dF_deta;

    if (this->verbose == 1) {
        printf("df_deta: %.8f\n", dF_deta);
        printf("new eta: %.8f\n", new_eta);
    }

    new_v[c + num_items * 2] = new_eta;

    // Project a field value back into [0, 1] (see header comment).
    auto clamp_unit = [](valueType &x) {
        if (std::signbit(x)) {
            x = 0.0;
        }
        if (x >= 1.0) {
            x = 1.0;
        }
    };
    clamp_unit(new_v[c]);
    clamp_unit(new_v[c + num_items]);
    clamp_unit(new_v[c + num_items * 2]);

    if (this->verbose == 2) {
        printf("vals: ");
        for (uint itr = 0; itr < lap_len_2nd * 3; itr++) {
            printf("%.2f, ", vals[itr]);
        }
        printf(" get cv=%.2f, ci=%.2f, eta=%.2f\n", new_v[c], new_v[c + num_items], new_v[c + num_items*2]);
    }
}

// In place: floor each of the first `len` entries of `mat` at `eps`, then
// replace it with its natural logarithm.
void NanoVoidNormal::log_with_mask(valueType *mat, valueType eps, uint len) {
    for (valueType *cell = mat, *stop = mat + len; cell != stop; ++cell) {
        const valueType floored = (*cell < eps) ? eps : *cell;
        *cell = log(floored);
    }
}

// Natural log of `p` with `p` floored at `eps`. Inputs below eps (including
// zero and negatives) yield log(eps) instead of -inf/NaN.
valueType NanoVoidNormal::log_with_mask_single(valueType p, valueType eps) {
    return log(p < eps ? eps : p);
}

// Overwrite mat[i] with the fill value `eps` wherever mask[i] == 1, for the
// first `len` entries. Entries whose mask is anything other than exactly 1
// are left untouched.
void NanoVoidNormal::masked_fill(valueType *mat, int *mask, valueType eps, uint len) {
    const int *flag = mask;
    for (valueType *cell = mat; cell != mat + len; ++cell, ++flag) {
        if (*flag == 1) {
            *cell = eps;
        }
    }
}

// Load a size x size image with three channels, indexed img[x][y][channel],
// into old_v. Channels 0, 1 and 2 are written at the indices returned by
// to_item_c1, to_item_c2 and to_item_c3 respectively.
//
// A former pass that bucketed pixels via LSH hashing and a union-find
// structure (inv / lsh / hash_t) was disabled here and has been removed; only
// the raw value table is populated.
void NanoVoidNormal::encode_from_img(valueType ***img) {
    Coordinate2d3c pos(0, 0);
    for (pos.x = 0; pos.x < size; pos.x++) {
        for (pos.y = 0; pos.y < size; pos.y++) {
            const valueType *pixel = img[pos.x][pos.y];
            old_v[pos.to_item_c1(size, size * size)] = pixel[0];
            old_v[pos.to_item_c2(size, size * size)] = pixel[1];
            old_v[pos.to_item_c3(size, size * size)] = pixel[2];
        }
    }
}

// Copy old_v out into a freshly allocated size x size image of 3-channel
// pixels, indexed result[x][y][channel]. Channel ch of a pixel is read at its
// first-channel index plus ch * num_items.
//
// Ownership passes to the caller. Every level (pixels, rows, the top array)
// is allocated with new[] and must be released with delete[].
valueType *** NanoVoidNormal::decode_to_img() {
    Coordinate2d3c pos(0, 0);

    valueType ***image = new valueType **[size];
    for (pos.x = 0; pos.x < size; pos.x++) {
        valueType **line = new valueType *[size];
        for (pos.y = 0; pos.y < size; pos.y++) {
            const uint base = pos.to_item_c1(size, num_items);
            valueType *pixel = new valueType[3];
            pixel[0] = old_v[base];
            pixel[1] = old_v[base + num_items];
            pixel[2] = old_v[base + num_items * 2];
            line[pos.y] = pixel;
        }
        image[pos.x] = line;
    }
    return image;
}

void NanoVoidNormal::next() {
    // process each pixel i,j from 0 to size-1
    valueType vals[this->vals_len];
    for (int i = 0; i < this->size; i++) {
        for (int j = 0; j < this->size; j++) {
            grab_vals(i * this->size + j, old_v, vals);
            forward_one_step(vals, i * this->size + j, new_v);
        }
    }
    std::memcpy(old_v, new_v, this->value_table_size * sizeof(valueType));
}

// 13-point stencil offsets used by grab_vals:
//   index 0      : the cell itself
//   indices 1-4  : nearest neighbours (+x, +y, -x, -y)
//   indices 5-8  : diagonal neighbours
//   indices 9-12 : cells two steps away along each axis
const int NanoVoidNormal::dx[] = { 0,  1,  0, -1,  0,  1, -1,  1, -1,  2,  0, -2,  0};
const int NanoVoidNormal::dy[] = { 0,  0,  1,  0, -1,  1,  1, -1, -1,  0,  2,  0, -2};
// Discrete biharmonic (Laplacian of Laplacian) weights, aligned with dx/dy.
const valueType NanoVoidNormal::laplapw[] = {20, -8, -8, -8, -8,  2,  2,  2,  2,  1,  1,  1,  1};
// 5-point Laplacian weights over the first five stencil points (centre + nearest).
const valueType NanoVoidNormal::lapw[] = {-4,  1,  1,  1,  1};